
import logging
from typing import Any, List, Mapping, Optional, Set, Dict

import requests
import json
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Field

from langchain_community.llms.utils import enforce_stop_tokens

logger = logging.getLogger(__name__)


class Yuan2(LLM):
    """Yuan2.0 language models.

    Example:
        .. code-block:: python

            yuan_llm = Yuan2(
                infer_api="http://127.0.0.1:8900/yuan",
                max_tokens=1024,
                temp=1.0,
                top_p=0.9,
                top_k=40,
            )
            print(yuan_llm)
            print(yuan_llm("你是谁？"))
    """

    infer_api: str
    """Inference api"""

    max_tokens: int = Field(1024, alias="max_token")
    """Token context window."""

    temp: Optional[float] = 0.7
    """The temperature to use for sampling."""

    top_p: Optional[float] = 0.9
    """The top-p value to use for sampling."""

    top_k: Optional[int] = 40
    """The top-k value to use for sampling."""

    do_sample: bool = False
    """The do_sample is a Boolean value that determines whether to use the sampling method during text generation."""

    echo: Optional[bool] = False
    """Whether to echo the prompt."""

    stop: Optional[List[str]] = []
    """A list of strings to stop generation when encountered."""

    repeat_last_n: Optional[int] = 64
    """Last n tokens to penalize"""

    repeat_penalty: Optional[float] = 1.18
    """The penalty to apply to repeated tokens."""

    streaming: bool = False
    """Whether to stream the results or not."""

    history: List[str] = []
    """History of the conversation"""

    use_history: bool = False
    """Whether to use history or not"""

    @property
    def _llm_type(self) -> str:
        """Return the type of llm."""
        return "Yuan2.0"

    @staticmethod
    def _model_param_names() -> Set[str]:
        """Names of the sampling parameters exposed by this model."""
        return {
            "max_tokens",
            "temp",
            "top_k",
            "top_p",
            "do_sample",
        }

    def _default_params(self) -> Dict[str, Any]:
        """Default parameters sent with every inference request."""
        return {
            "infer_api": self.infer_api,
            "max_tokens": self.max_tokens,
            "temp": self.temp,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "do_sample": self.do_sample,
            "use_history": self.use_history,
        }

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "model": self._llm_type,
            **self._default_params(),
            **{
                k: v for k, v in self.__dict__.items() if k in self._model_param_names()
            },
        }

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Call out to a Yuan2.0 LLM inference endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Raises:
            ValueError: If the request fails, the response is not valid JSON,
                or the response payload reports an error / lacks output.

        Example:
            .. code-block:: python

                response = yuan_llm("Who are you?")
        """
        # Prepend the accumulated conversation when multi-turn chat is enabled.
        # (Renamed from `input`, which shadowed the builtin.)
        if self.use_history:
            input_text = "\n".join(self.history) + "\n" + prompt
            self.history.append(prompt)
        else:
            input_text = prompt

        headers = {"Content-Type": "application/json"}
        data = json.dumps(
            {
                "ques_list": [
                    {
                        "id": "000",
                        "ques": input_text,
                    }
                ],
                "tokens_to_generate": self.max_tokens,
                "temperature": self.temp,
                "top_p": self.top_p,
                "top_k": self.top_k,
                "do_sample": self.do_sample,
            }
        )

        # Fix: logging takes lazy %-style args; the original passed the prompt
        # as a stray positional argument and never logged it.
        logger.debug("prompt: %s", input_text)

        # call api
        try:
            response = requests.put(self.infer_api, headers=headers, data=data)
        except requests.exceptions.RequestException as e:
            raise ValueError(f"Error raised by inference api: {e}") from e

        logger.debug("Yuan2.0 response: %s", response)

        if response.status_code != 200:
            raise ValueError(f"Failed with response: {response}")
        try:
            resp = response.json()

            if resp["errCode"] != "0":
                raise ValueError(
                    f"Failed with error code [{resp['errCode']}], "
                    f"error message: [{resp['errMessage']}]"
                )

            if "resData" in resp:
                # Fix: `>= 0` was always true, making the error branch
                # unreachable even for an empty output list.
                if len(resp["resData"]["output"]) > 0:
                    generate_text = resp["resData"]["output"][0]["ans"]
                else:
                    raise ValueError("No output found in response.")
            else:
                raise ValueError("No resData found in response.")

        except requests.exceptions.JSONDecodeError as e:
            raise ValueError(
                f"Error raised during decoding response from inference api: {e}."
                f"\nResponse: {response.text}"
            ) from e

        if stop is not None:
            generate_text = enforce_stop_tokens(generate_text, stop)

        # support multi-turn chat
        if self.use_history:
            self.history.append(generate_text)

        logger.debug("history: %s", self.history)

        # Fix: str.strip("<eod>") treats the argument as a character SET and
        # could eat genuine leading/trailing '<', 'e', 'o', 'd', '>' characters.
        # Remove only the exact end-of-document sentinel instead.
        if generate_text.endswith("<eod>"):
            generate_text = generate_text[: -len("<eod>")]
        return generate_text


def langchain_demo():
    """Smoke demo: build a Yuan2 instance and run a single prompt."""
    llm = Yuan2(
        infer_api="http://127.0.0.1:8900/yuan",
        max_tokens=1024,
        temp=1.0,
        top_p=0.9,
        top_k=40,
        use_history=False,
    )
    print(llm)
    print(llm("What NFL team won the Super Bowl in the year Justin Beiber was born? Let's think step by step"))


"""Test ChatGLM API wrapper."""
from langchain.schema import LLMResult


def test_yuan2_call() -> None:
    """Test valid call to Yuan2.0."""
    model = Yuan2(
        infer_api="http://127.0.0.1:8900/yuan",
        max_tokens=1024,
        temp=1.0,
        top_p=0.9,
        top_k=40,
        use_history=False,
    )
    result = model("写一段快速排序算法。")
    assert isinstance(result, str)


def test_yuan2_generate() -> None:
    """Test valid call to Yuan2.0 inference api."""
    model = Yuan2(
        infer_api="http://127.0.0.1:8900/yuan",
        max_tokens=1024,
        temp=1.0,
        top_p=0.8,
        top_k=0,
        use_history=False,
    )
    result = model.generate(["what is vitamin？"])
    print(result)
    assert isinstance(result, LLMResult)
    assert isinstance(result.generations, list)

def test_yuan2_template() -> None:
    """Run a PromptTemplate-driven LLMChain against Yuan2.0."""
    from langchain.chains import LLMChain
    from langchain.prompts import PromptTemplate

    model = Yuan2(
        infer_api="http://127.0.0.1:8900/yuan",
        max_tokens=1024,
        temp=1.0,
        top_p=0.9,
        top_k=40,
        use_history=False,
    )
    question_template = PromptTemplate(
        template="""{question}""",
        input_variables=["question"],
    )
    chain = LLMChain(prompt=question_template, llm=model)
    print(chain.run("写一段快速排序算法"))


# test_yuan2_generate()
